#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ======================
#  crash-safe logging
# ======================
import os, sys, re, math, difflib, random, argparse, datetime, signal, traceback, faulthandler
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.distributed as dist
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Directory that receives per-rank log files and result shards.
# main() overwrites this module-level default with the --out_dir argument.
OUT_DIR = "runs_dp"
# Filename stem shared by per-rank logs (`<stem>.rank<r>.log`) and shards
# (`<stem>.rank<r>.pt`). main() overwrites it with --run_name, or with a
# task/timestamp-derived name when --run_name is not given.
RUN_NAME = "unified_run"

def _rank_log_path(rank):
    """Return the absolute path of the log file for `rank`.

    The output directory (module-level OUT_DIR) is created on demand so the
    returned path can be opened for appending right away.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    log_dir = os.path.abspath(OUT_DIR)
    log_name = f"{RUN_NAME}.rank{rank}.log"
    return os.path.join(log_dir, log_name)

def _log_exc(rank, where, extra=""):
    """Append the currently handled exception to this rank's log file.

    Args:
        rank: rank whose log file receives the entry.
        where: short label naming the code location that failed.
        extra: optional free-form context line written before the traceback.

    Never raises: any failure while logging is swallowed on purpose, because
    this is the last-resort reporting path.
    """
    try:
        with open(_rank_log_path(rank), "a") as log:
            stamp = datetime.datetime.now()
            log.write(f"\n[{stamp}] RANK {rank} EXCEPTION @ {where}\n")
            if extra:
                log.write(extra + "\n")
            # format_exc() reads the exception that is being handled by the caller.
            log.write(traceback.format_exc() + "\n")
            log.flush()
    except Exception:
        pass  # last-resort logging

def _install_signal_dumps(rank):
    """Route fault dumps and SIGTERM/SIGINT stack dumps to this rank's log.

    The log file is opened line-buffered and intentionally left open for the
    lifetime of the process: faulthandler and the signal handler both write
    to it later.
    """
    sink = open(_rank_log_path(rank), "a", buffering=1)
    faulthandler.enable(file=sink)

    def _on_signal(signum, frame):
        # Record which signal arrived, dump every thread's stack, then exit.
        stamp = datetime.datetime.now()
        sink.write(f"\n[{stamp}] RANK {rank} got signal {signum}, dumping stacks...\n")
        faulthandler.dump_traceback(file=sink, all_threads=True)
        sink.flush()
        raise SystemExit(1)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _on_signal)

# ==========================
#     DEFAULT MODELS (Qwen-style)
# ==========================
# Model id used as the default for --gen_model (the generation model).
DEFAULT_GEN_MODEL = "Qwen/Qwen3-1.7B"
# Model id used as the default for --cls_model (the step classifier).
# main() reuses the already-loaded generation model when both ids are equal.
DEFAULT_CLS_MODEL = "Qwen/Qwen3-1.7B"  # by default reuse the same model

# ==========================
#   PROMPTS (dataset-specific)
# ==========================

# Instruction header for four-option multiple-choice questions (options A-D).
# Asks for step-by-step reasoning and a last line of the form
# "Final Answer: \boxed{<letter>}".
PROMPT_GPQA = (
    "You are answering a multiple-choice question.\n"
    "Options are labeled A, B, C, and D.\n"
    "Think step-by-step and show your reasoning.\n"
    "At the very end, output ONE line exactly in this format:\n"
    "Final Answer: \\boxed{A}\n"
    "where the letter is A, B, C, or D.\n"
)

# Instruction header for free-form questions: step-by-step reasoning followed
# by a single "Final Answer: \boxed{...}" line.
PROMPT_AIME_TIGER = (
    "Answer the following question step-by-step.\n"
    "At the very end, output exactly one line formatted as:\n"
    "Final Answer: \\boxed{...}\n"
)

# Instruction header for the math500 task; currently the same text as
# PROMPT_AIME_TIGER.
PROMPT_MATH500 = (
    "Answer the following question step-by-step.\n"
    "At the very end, output exactly one line formatted as:\n"
    "Final Answer: \\boxed{...}\n"
)

# Tagging (optional)
# Closed set of step labels. The classifier's free-text output is mapped onto
# one of these, and the list is also stored in every saved record.
TAG_LIST = [
    "final_answer",
    "setup_and_retrieval",
    "analysis_and_computation",
    "uncertainty_and_verification",
]
# Instruction preamble placed at the start of the classification prompt.
# Despite the name it lists only the four tag names (no worked examples).
FEW_SHOT_PREFIX_4 = """You are an expert in reasoning analysis. 
Classify the function of each sentence into one of the following tags:
1. final_answer
2. setup_and_retrieval
3. analysis_and_computation
4. uncertainty_and_verification
"""

# Option labels for multiple-choice questions, in display order.
LETTERS = ["A", "B", "C", "D"]
# Case-insensitive pattern that captures one option letter (group 1).
# The letter may be preceded by "Final Answer:", may be wrapped as \boxed{X},
# or may simply stand alone at a word boundary (optionally followed by "." or "}").
LETTER_RE = re.compile(r"(?i)(?:Final Answer\s*:\s*)?(?:\\boxed\{|\b)([A-D])(?:\}|\.|\b)")

# ==========================
#   DISTRIBUTED HELPERS
# ==========================
def ddp_init():
    """Initialise the NCCL process group and bind this process to one GPU.

    Returns:
        (rank, world, device): the global rank, the world size, and the
        torch.device this process should place its tensors on.

    Raises:
        RuntimeError: if no CUDA device is visible to this process.

    The GPU index is derived once and used both for `torch.cuda.set_device`
    and for the returned device. Previously the returned device was
    `cuda:{global_rank}`, which does not exist when the global rank is not
    smaller than the number of GPUs visible on this node (e.g. multi-node runs).
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()

    n_gpus = torch.cuda.device_count()
    if n_gpus < 1:
        raise RuntimeError("ddp_init requires at least one visible CUDA device.")

    # Launchers such as torchrun export LOCAL_RANK (the per-node index); fall
    # back to the global rank when it is absent. The modulo keeps the index
    # valid even if each process only sees a subset of the node's GPUs.
    local_rank = int(os.environ.get("LOCAL_RANK", rank)) % n_gpus
    torch.cuda.set_device(local_rank)
    return rank, world, torch.device(f"cuda:{local_rank}")

# ==========================
#   DETERMINISM HELPERS
# ==========================
def set_global_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs with `seed + rank`.

    A `seed` of None is treated as 0. Returns the combined seed that was used.

    Besides seeding, this sets CUBLAS_WORKSPACE_CONFIG if it is not already
    set, and explicitly turns OFF deterministic algorithms, cuDNN
    deterministic mode and cuDNN benchmarking.
    """
    base = 0 if seed is None else seed
    full_seed = base + int(rank)

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(full_seed)

    # Keep a user-provided value if one is already exported.
    if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    try:
        torch.use_deterministic_algorithms(False)
    except Exception:
        pass

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = False
    return full_seed

def reseed_for_sample(base_seed: int, rank: int, local_idx: int) -> int:
    """
    Re-seed every RNG with a value derived from (base_seed, rank, local_idx).

    The derived seed mixes the three inputs with fixed multipliers and is
    reduced modulo 2**31 - 1. It is applied to Python's `random`, NumPy and
    PyTorch (CPU and all CUDA devices) so that sampling for a given item is
    reproducible. A `base_seed` of None is treated as 0.

    Returns the derived seed.
    """
    modulus = 2**31 - 1
    base = 0 if base_seed is None else base_seed
    derived = (base * 1_000_003 + rank * 97_931 + local_idx) % modulus

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(derived)
    return derived

# ==========================
#   TEXT ACCURACY HELPERS
# ==========================
def _pick_answer(text: str) -> str:
    """Extract the final answer string from a model output or a gold solution.

    Strategy, in order:
      1. the contents of the LAST non-empty, brace-balanced ``\\boxed{...}``;
      2. the remainder of the first line starting with "Final Answer:";
      3. the last non-empty line (or the stripped text if there is none).

    Braces are matched by depth so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` is returned whole; a non-greedy regex would stop
    at the first ``}`` and make different fractions look identical. The last
    box is used because the prompts ask for the final answer at the very end,
    while earlier boxes may hold intermediate results.
    """
    answer = None
    for opener in re.finditer(r'\\boxed\{', text):
        start = opener.end()
        depth = 1
        for pos in range(start, len(text)):
            ch = text[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    inner = text[start:pos].strip()
                    if inner:
                        answer = inner
                    break
        # An opener whose braces never balance (e.g. truncated output) is ignored.
    if answer is not None:
        return answer

    m = re.search(r'(?i)^\s*final\s*answer\s*:\s*(.+)$', text, flags=re.M)
    if m:
        return m.group(1).strip()

    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    return lines[-1] if lines else text.strip()

def _to_number_maybe(s: str):
    """Parse `s` as a number, returning a float or None when it is not numeric.

    Accepted forms are anything `float()` accepts that contains a digit, and
    simple integer fractions like ``-3/4``. Commas are removed only where they
    are thousands separators (``1,234`` or ``1,234,567.8``); a string such as
    ``1,2`` is a pair of values, not the number 12, and yields None.
    Division by zero and unparsable text also yield None.
    """
    s = s.strip()
    # Strip commas only inside well-formed thousands groupings.
    s = re.sub(
        r'(?<![\d,.])\d{1,3}(?:,\d{3})+(?![\d,])',
        lambda m: m.group(0).replace(',', ''),
        s,
    )
    if re.fullmatch(r'[+-]?\d+/\d+', s):
        num, den = s.split('/')
        try:
            return float(num) / float(den)
        except (ValueError, ZeroDivisionError):
            return None
    if re.search(r'\d', s):
        try:
            return float(s)
        except ValueError:
            return None
    return None

def _norm_text(s: str) -> str:
    """Lower-case `s`, trim it, and collapse each whitespace run to one space."""
    lowered = s.lower().strip()
    return re.sub(r'\s+', ' ', lowered)

def compare_answers(pred_text: str, gold_text: str, atol: float = 1e-6) -> bool:
    """Decide whether a predicted answer matches the gold answer.

    Both texts are first reduced to their answer string. If both parse as
    numbers they are compared with an absolute tolerance of `atol` (no
    relative tolerance); otherwise the normalised strings must be equal.
    """
    pred, gold = _pick_answer(pred_text), _pick_answer(gold_text)
    pred_num = _to_number_maybe(pred)
    gold_num = _to_number_maybe(gold)

    if pred_num is None or gold_num is None:
        # At least one side is not numeric: fall back to text equality.
        return _norm_text(pred) == _norm_text(gold)
    return math.isclose(pred_num, gold_num, rel_tol=0.0, abs_tol=atol)

# ==========================
#  SEGMENTATION + ALIGNMENT
# ==========================
def preclean_generated_text(s: str) -> str:
    """Normalise line endings and drop an echoed instruction header.

    CRLF becomes LF, then a leading "answer the following ...:" line
    (case-insensitive, only at the very start of the text) is removed together
    with the newlines that follow it.
    """
    unix_text = s.replace("\r\n", "\n")
    echoed_header = r'^\s*answer the following[^\n]*:\s*\n+'
    return re.sub(echoed_header, '', unix_text, flags=re.I)

def paragraph_spans(text: str) -> List[Tuple[int,int]]:
    """Return (start, end) character spans of paragraphs in `text`.

    Paragraphs are the non-empty stretches between runs of two or more
    newlines; the separators themselves are excluded. Spans are half-open
    and listed in document order.
    """
    separators = [(m.start(), m.end()) for m in re.finditer(r"\n{2,}", text)]
    # Each paragraph starts where the previous separator ended (or at 0) and
    # ends where the next separator begins (or at the end of the text).
    starts = [0] + [sep_end for _, sep_end in separators]
    ends = [sep_start for sep_start, _ in separators] + [len(text)]
    return [(lo, hi) for lo, hi in zip(starts, ends) if hi > lo]

def charspan_to_tokenspan(offsets, a, b):
    """Map the character span [a, b) to a half-open token index span (t0, t1).

    `offsets` is a sequence of (char_start, char_end) pairs, one per token.
    Tokens ending at or before `a` are skipped to find t0; tokens starting
    before `b` are then consumed to find t1. Entries equal to (0, 0) are
    stepped over in both phases.
    """
    total = len(offsets)
    filler = (0, 0)

    cursor = 0
    # Phase 1: advance past tokens that lie entirely before the span.
    while cursor < total:
        tok = offsets[cursor]
        if not (tok[1] <= a or tok == filler):
            break
        cursor += 1
    first = cursor

    # Phase 2: advance over tokens that begin inside the span.
    while cursor < total:
        tok = offsets[cursor]
        if not (tok[0] < b or tok == filler):
            break
        cursor += 1

    return first, cursor

# ==========================
#     QWEN CHAT HELPERS
# ==========================
def qwen_build_prompt(tokenizer, user_text: str, enable_thinking: bool) -> str:
    """
    Render a single-turn user message through the tokenizer's chat template.

    The result is returned as text (not token ids) with the generation prompt
    appended. `enable_thinking` is forwarded to the template unchanged; callers
    in this file pass True for generation and False for classification.
    """
    chat = [dict(role="user", content=user_text)]
    template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
        "enable_thinking": enable_thinking,
    }
    return tokenizer.apply_chat_template(chat, **template_kwargs)

# ==========================
#        CLASSIFIER
# ==========================
def classify_step_4tags_qwen(tok, model, text: str) -> str:
    """Label one reasoning step with a tag from TAG_LIST.

    Builds a non-thinking chat prompt asking for a single label, decodes
    greedily (up to 64 new tokens), and maps the last output line onto the
    closest entry of TAG_LIST. Because the similarity cutoff is 0.0 the
    closest tag is accepted however weak the match; TAG_LIST[0] is only a
    fallback for an empty match list.
    """
    instruction = (
        f"{FEW_SHOT_PREFIX_4}\n\n"
        f"Now classify the following sentence with ONE label only.\n"
        f"Sentence: \"{text}\"\n"
        f"Choose one label from: {', '.join(TAG_LIST)}\n"
        f"Label:"
    )
    prompt_text = qwen_build_prompt(tok, user_text=instruction, enable_thinking=False)
    inputs = tok(prompt_text, return_tensors="pt").to(model.device)

    decode_kwargs = dict(
        max_new_tokens=64,
        do_sample=False,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    with torch.no_grad():
        out = model.generate(**inputs, **decode_kwargs)

    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_len = inputs["input_ids"].shape[-1]
    gen_txt = tok.decode(out[0][prompt_len:], skip_special_tokens=True)

    last_line = gen_txt.strip().split("\n")[-1]
    label_line = last_line.lower().replace("label:", "").strip()

    candidates = difflib.get_close_matches(label_line, TAG_LIST, n=1, cutoff=0.0)
    if not candidates:
        return TAG_LIST[0]
    return candidates[0]

# ==========================
#          MAIN
# ==========================
def main():
    """
    Entry point: run one benchmark task on this rank's shard and save results.

    Steps: parse CLI arguments, initialise the process group, load the
    generation model (and optionally a classifier), generate one sampled answer
    per item of this rank's dataset shard, and write the collected records to
    `<OUT_DIR>/<RUN_NAME>.rank<rank>.pt`.

    Failures for an individual sample are logged to the per-rank log file and
    stored as an `{"error", "sample_idx"}` record; a failure outside the
    per-sample handlers is logged as "main-fatal" and reported on stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
        choices=["tiger", "aime24", "math500", "gpqa_diamond"],
        required=True)
    parser.add_argument("--gen_model", default=DEFAULT_GEN_MODEL)
    parser.add_argument("--cls_model", default=DEFAULT_CLS_MODEL,
                        help="If same as --gen_model (default), reuse the same Qwen model for classification.")
    parser.add_argument("--out_dir", default="runs_dp")
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=4000)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=20)
    parser.add_argument("--min_p", type=float, default=0.0)
    parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--classify_steps", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--shuffle_dataset", action="store_true",
                    help="Shuffle dataset order deterministically with --seed.")
    parser.add_argument("--tiger_answer_types", nargs="*", default=["Float","Multiple Choice","Integer","Percentage"])
    parser.add_argument("--tiger_difficulties", nargs="*", default=["Senior High School","Junior High School","Primary School"])
    args = parser.parse_args()

    global OUT_DIR, RUN_NAME
    OUT_DIR = args.out_dir
    RUN_NAME = args.run_name or f"{args.task}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # ===== DDP init & per-rank signal dump hooks =====
    rank, world, device = ddp_init()
    if not args.run_name:
        # The default name embeds a wall-clock timestamp that every process
        # computes on its own; adopt rank 0's value so that all shards and
        # logs of one run share the same filename stem.
        shared_name = [RUN_NAME]
        dist.broadcast_object_list(shared_name, src=0)
        RUN_NAME = shared_name[0]
    _install_signal_dumps(rank)
    if rank == 0:
        os.makedirs(OUT_DIR, exist_ok=True)

    # ===== Seeds / determinism =====
    set_global_seed(args.seed, rank)
    gpqa_seed = args.seed

    # ===== Perf toggles =====
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # ===== DTYPE =====
    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
    DTYPE = dtype_map[args.dtype]

    # ===== Qwen tokenizer/model for generation =====
    gen_tok = AutoTokenizer.from_pretrained(args.gen_model, trust_remote_code=True)
    if gen_tok.pad_token is None:
        # Qwen models often need pad_token set explicitly
        gen_tok.pad_token = gen_tok.eos_token

    gen_model = AutoModelForCausalLM.from_pretrained(
        args.gen_model, dtype=DTYPE, trust_remote_code=True, device_map=None, attn_implementation="flash_attention_2"
    ).to(device).eval()

    # ===== Classifier model/tokenizer (reuse gen by default) =====
    if args.classify_steps:
        if args.cls_model and (args.cls_model != args.gen_model):
            cls_tok = AutoTokenizer.from_pretrained(args.cls_model, trust_remote_code=True)
            if cls_tok.pad_token is None:
                cls_tok.pad_token = cls_tok.eos_token
            cls_model = AutoModelForCausalLM.from_pretrained(
                args.cls_model, dtype=DTYPE, trust_remote_code=True, device_map=None, attn_implementation="flash_attention_2"
            ).to(device).eval()
        else:
            cls_tok, cls_model = gen_tok, gen_model
        try:
            # Greedy decoding for classification: clear sampling defaults so
            # they are not combined with do_sample=False.
            ccfg = cls_model.generation_config
            ccfg.do_sample = False; ccfg.top_p = None; ccfg.top_k = None; ccfg.temperature = None
        except Exception:
            pass
    else:
        cls_tok = cls_model = None

    records, num_correct = [], 0

    # A letter explicitly wrapped in \boxed{...}; preferred over a bare letter.
    boxed_letter_re = re.compile(r"\\boxed\{\s*([A-Da-d])\s*\}")

    def _last_nonempty_line(s: str) -> str:
        """Return the last line of `s` that is not blank (stripped), or ""."""
        for line in reversed(s.splitlines()):
            line = line.strip()
            if line:
                return line
        return ""

    def _pick_letter(text: str):
        """Return the answer letter (A-D, upper-case) found in `text`, or None.

        A boxed letter wins over any other match so that a stray standalone
        "a"/"A" earlier on the line is not mistaken for the answer.
        """
        boxed = boxed_letter_re.findall(text)
        if boxed:
            return boxed[-1].upper()
        m = LETTER_RE.search(text)
        return m.group(1).upper() if m else None

    # ===== Unified single-sample generate+collect (Qwen prompt path) =====
    def generate_and_collect(
        user_gen_text: str,
        question: str,
        gold_answer: str,
        label_note: str,
        is_mc: bool = False,
        extra_mc: Optional[Dict] = None,
        sample_idx: int = 0
    ):
        """
        Generate one sampled answer, score it, and append a record to `records`.

        user_gen_text: the *user* content that will be wrapped using Qwen chat template.
        question / gold_answer / label_note: stored verbatim in the record.
        is_mc: score by answer letter instead of by free-form answer comparison.
        extra_mc:
          - for gpqa_diamond: dict with options_shuffled (map A-D->text), correct_letter
        sample_idx: index within this rank's shard, used to re-seed sampling.
        """
        nonlocal num_correct
        reseed_for_sample(args.seed, rank, sample_idx)

        # ---- Build Qwen chat prompt (generation)
        prompt_text = qwen_build_prompt(gen_tok, user_text=user_gen_text, enable_thinking=True)
        gen_inputs = gen_tok(prompt_text, return_tensors="pt").to(device)

        with torch.no_grad():
            out = gen_model.generate(
                **gen_inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=True, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
                eos_token_id=gen_tok.eos_token_id,
                pad_token_id=gen_tok.pad_token_id,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

        # Separate prompt vs generation
        seq = out.sequences[0]
        inp_len = gen_inputs["input_ids"].shape[1]
        gen_only_ids = seq[inp_len:]
        generated_text = gen_tok.decode(gen_only_ids, skip_special_tokens=True)
        generated_text = preclean_generated_text(generated_text)

        # Hidden states DURING generation -> [L+1, T_gen, D]
        hs = out.hidden_states
        if isinstance(hs[0], tuple):  # per-step style: T_gen x (L+1) x [1, S, D]
            # The first decoding step carries one position per PROMPT token
            # (S == prompt length); later steps carry a single position. Keep
            # only the last position of every step so that index t along the
            # token axis corresponds to generated token t rather than being
            # shifted by the prompt length.
            n_layers = len(hs[0])
            hidden_states = torch.stack(
                [torch.cat([step[li][0, -1:, :] for step in hs], dim=0) for li in range(n_layers)],
                dim=0,
            ).cpu()  # [L+1, T_gen, D]
        else:  # per-layer style: (L+1) x [1, T_gen, D]
            hidden_states = torch.stack([x[0] for x in hs], dim=0).cpu()

        # Step segmentation + offset alignment over generated_text
        # NOTE(review): token indices come from re-tokenizing the decoded text;
        # this presumably matches the generated ids closely but is not
        # guaranteed to be identical — confirm if exact alignment matters.
        enc = gen_tok(generated_text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc.offset_mapping
        if offsets is None:
            offsets = []
        if isinstance(offsets, torch.Tensor):
            offsets = offsets.tolist()
        # Unwrap a leading batch dimension if the tokenizer returned one.
        if len(offsets) > 0 and len(offsets[0]) > 0 and isinstance(offsets[0][0], (list, tuple)):
            offsets = offsets[0]
        offsets = [tuple(x) for x in offsets]

        char_spans = [(a, b) for a, b in paragraph_spans(generated_text) if generated_text[a:b].strip()]
        if not char_spans:
            char_spans = [(0, len(generated_text))]

        # Token spans are in gen-only space (prompt tokens are not included).
        token_spans = [charspan_to_tokenspan(offsets, a, b) for (a, b) in char_spans]

        # Aggregate per-step hidden vector (FIRST token)
        step_hidden_states = []
        sentences_with_labels = []
        MAX_CLASSIFY_CHARS = 600
        n_hidden_tokens = hidden_states.shape[1]
        can_classify = args.classify_steps and cls_tok is not None and cls_model is not None

        for (a, b), (t0, t1) in zip(char_spans, token_spans):
            clean = re.sub(r"\s+", " ", generated_text[a:b]).strip()

            if t1 > t0 and t0 < n_hidden_tokens:
                # clone(): a plain slice is a view that keeps the whole
                # [L+1, T_gen, D] tensor alive and makes torch.save write the
                # full storage for every record.
                step_hidden_states.append(hidden_states[:, t0, :].clone())  # [L+1, D]
            else:
                step_hidden_states.append(None)

            if clean:
                tag = classify_step_4tags_qwen(cls_tok, cls_model, clean[:MAX_CLASSIFY_CHARS]) if can_classify else None
                sentences_with_labels.append((clean, tag))

        # Accuracy logic
        pred_letter = None
        if is_mc:
            pred_letter = _pick_letter(_last_nonempty_line(generated_text))
            gold_letter = extra_mc.get("correct_letter") if extra_mc else None
            is_correct = bool(pred_letter and gold_letter) and (pred_letter == gold_letter)
        else:
            is_correct = compare_answers(generated_text, gold_answer) if gold_answer else False

        # Save one record (no char/token spans as requested)
        rec = {
            "label_note": label_note,
            "prompt": prompt_text,             # store full chat prompt text (Qwen style)
            "generated_text": generated_text,
            "question": question,
            "ground_truth_answer": gold_answer,
            "is_correct": is_correct,
            "sentences_with_labels": sentences_with_labels,
            "step_hidden_states": step_hidden_states,  # list[(L+1, D)] or None
            "gen_token_count": int(gen_only_ids.numel()),
            "tag_list": TAG_LIST,
            "predicted_boxed": re.findall(r'\\boxed\{(.+?)\}', generated_text, flags=re.DOTALL)[:1],
        }

        if is_mc:
            # Record the letter that was actually scored.
            rec["predicted_letter"] = pred_letter
            if extra_mc:
                if "options_shuffled" in extra_mc:
                    rec["options_shuffled"] = extra_mc["options_shuffled"]
                if "correct_letter" in extra_mc:
                    rec["ground_truth_answer"] = extra_mc["correct_letter"]

        records.append(rec)
        num_correct += int(is_correct)

    def _prepare_shard(ds_full, always_announce: bool = False):
        """Take the first N items, optionally shuffle them, return this rank's shard."""
        total = len(ds_full)
        take = min(args.num_samples if args.num_samples is not None else total, total)
        if rank == 0:
            if always_announce:
                print(f"[INFO] After filtering, dataset has {total} items. Taking first {take}, then shuffling (if enabled).")
            elif take < (args.num_samples or total):
                print(f"[INFO] Requested {args.num_samples}, but dataset only has {total}. Taking first {take}, then shuffling (if enabled).")
        subset = ds_full.select(range(take))
        if args.shuffle_dataset:
            subset = subset.shuffle(seed=args.seed)
        return subset.shard(num_shards=world, index=rank)

    def _record_failure(where: str, idx: int, exc: Exception, extra: str):
        """Log the exception being handled and keep an error record for the sample."""
        _log_exc(rank, where, extra)
        records.append({"error": f"{type(exc).__name__}: {exc}", "sample_idx": idx})

    # ===== Load + run per task =====
    saved = False
    try:
        task = args.task

        if task == "tiger":
            ds_full = load_dataset("TIGER-Lab/WebInstruct-verified", split="test")
            ats = set(args.tiger_answer_types)
            diffs = set(args.tiger_difficulties)
            ANSWER_TYPE_KEY = "answer_type"

            # 1) filter on whichever of the two columns exist
            has_type = ANSWER_TYPE_KEY in ds_full.column_names
            has_diff = "difficulty" in ds_full.column_names
            if has_type and has_diff:
                ds_full = ds_full.filter(lambda x: x[ANSWER_TYPE_KEY] in ats and x["difficulty"] in diffs)
            elif has_type:
                ds_full = ds_full.filter(lambda x: x[ANSWER_TYPE_KEY] in ats)
            elif has_diff:
                ds_full = ds_full.filter(lambda x: x["difficulty"] in diffs)
            else:
                print("[WARN] Columns 'answer_type' and 'difficulty' not found; dataset may be empty.")
                ds_full = ds_full.select([])

            # 2) cut first N, 3) shuffle slice (deterministic), 4) shard after shuffle
            ds = _prepare_shard(ds_full, always_announce=True)

            for idx, sample in enumerate(tqdm(ds, disable=(rank != 0))):
                try:
                    question = str(sample["question"]).strip()
                    gold_answer = str(sample["answer"]).strip()
                    prompt = f"{PROMPT_AIME_TIGER}\n{question}\n"
                    generate_and_collect(prompt, question, gold_answer,
                                        label_note="tiger_filtered", is_mc=False, sample_idx=idx)
                except Exception as e:
                    _record_failure("tiger_filtered-per-sample", idx, e,
                                    f"idx={idx} | q[:120]={str(sample.get('question','')).strip()[:120]}")
                    continue

        elif task == "aime24":
            ds_full = load_dataset("HuggingFaceH4/aime_2024", split="train")
            ds = _prepare_shard(ds_full)

            for idx, sample in enumerate(tqdm(ds, disable=(rank != 0))):
                try:
                    question = str(sample["problem"]).strip()
                    gold_answer = str(sample["answer"]).strip()
                    prompt = f"{PROMPT_AIME_TIGER}\n{question}\n"
                    generate_and_collect(prompt, question, gold_answer,
                                        label_note="aime24", is_mc=False, sample_idx=idx)
                except Exception as e:
                    _record_failure("aime24-per-sample", idx, e,
                                    f"idx={idx} | q[:120]={str(sample.get('problem','')).strip()[:120]}")
                    continue

        elif task == "math500":
            ds_full = load_dataset("HuggingFaceH4/MATH-500", split="test")
            ds = _prepare_shard(ds_full)

            for idx, sample in enumerate(tqdm(ds, disable=(rank != 0))):
                try:
                    question = str(sample["problem"]).strip()
                    # The gold text is the full worked solution; the answer is
                    # extracted from it when the prediction is compared.
                    gold_answer = str(sample["solution"]).strip()
                    prompt = f"{PROMPT_MATH500}\n{question}\n"
                    generate_and_collect(prompt, question, gold_answer,
                                        label_note="math500", is_mc=False, sample_idx=idx)
                except Exception as e:
                    _record_failure("math500-per-sample", idx, e,
                                    f"idx={idx} | q[:120]={str(sample.get('problem','')).strip()[:120]}")
                    continue

        elif task == "gpqa_diamond":
            ds_full = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
            ds = _prepare_shard(ds_full)

            for idx, sample in enumerate(tqdm(ds, disable=(rank != 0))):
                try:
                    needed = ["Question","Correct Answer","Incorrect Answer 1","Incorrect Answer 2","Incorrect Answer 3"]
                    for k in needed:
                        if k not in sample:
                            raise ValueError(f"Missing column '{k}' in sample.")

                    # per-sample deterministic shuffle of options
                    rng = random.Random((gpqa_seed if gpqa_seed is not None else 0) + idx)

                    q = str(sample["Question"]).strip()
                    opts = [
                        str(sample["Correct Answer"]).strip(),
                        str(sample["Incorrect Answer 1"]).strip(),
                        str(sample["Incorrect Answer 2"]).strip(),
                        str(sample["Incorrect Answer 3"]).strip(),
                    ]
                    idxs = [0,1,2,3]; rng.shuffle(idxs)
                    shuf = [opts[i] for i in idxs]
                    # opts[0] is the correct answer; find where it landed.
                    correct_idx = idxs.index(0)
                    correct_letter = LETTERS[correct_idx]

                    options_block = "\n".join(f"{LETTERS[i]}. {shuf[i]}" for i in range(4))
                    prompt = f"{PROMPT_GPQA}\n{q}\n\n{options_block}\n"

                    extra = {"options_shuffled": {LETTERS[i]: shuf[i] for i in range(4)},
                            "correct_letter": correct_letter}
                    generate_and_collect(prompt, q, correct_letter,
                                        label_note="gpqa_diamond", is_mc=True, extra_mc=extra, sample_idx=idx)
                except Exception as e:
                    _record_failure("gpqa-per-sample", idx, e,
                                    f"idx={idx} | keys={list(sample.keys())[:8]}")
                    continue

        # ===== Save shard =====
        os.makedirs(OUT_DIR, exist_ok=True)
        shard_path = os.path.join(OUT_DIR, f"{RUN_NAME}.rank{rank}.pt")
        # accuracy_partial counts every record, including per-sample error records.
        torch.save({"records": records, "accuracy_partial": (num_correct, len(records))}, shard_path)
        saved = True

    except Exception:
        _log_exc(rank, "main-fatal")
        print(f"[Rank {rank}] Fatal error; no shard written. See {_rank_log_path(rank)}", file=sys.stderr)
    finally:
        if dist.is_initialized():
            try:
                dist.destroy_process_group()
            except Exception:
                pass
        # Use the rank captured at start-up: after destroy_process_group the
        # initialised-check is always False and cannot identify the rank.
        if saved:
            if rank == 0:
                print(f"[Rank 0] Shards/logs saved to {OUT_DIR}. Merge with merge_shards.py.")
            else:
                print(f"[Rank {rank}] Shard saved to {OUT_DIR}.")

if __name__ == "__main__":
    main()
